{
    "cells": [
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# `CheMeleon` Foundation Finetuning\n",
                "\n",
                "This notebook demonstrates how to use the `CheMeleon` foundation model with Chemprop to achieve accurate prediction on small datasets.\n",
                "One can also use this functionality from the Command Line Interface by using `--from-foundation chemeleon`."
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/chemeleon_foundation_finetuning.ipynb)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 1,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Install chemprop from GitHub if running in Google Colab\n",
                "import os\n",
                "\n",
                "if os.getenv(\"COLAB_RELEASE_TAG\"):\n",
                "    try:\n",
                "        import chemprop\n",
                "    except ImportError:\n",
                "        !git clone https://github.com/chemprop/chemprop.git\n",
                "        %cd chemprop\n",
                "        !pip install .\n",
                "        %cd examples"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Retrieving the `CheMeleon` Model\n",
                "\n",
                "The `CheMeleon` model file is stored on Zenodo at [this link](https://zenodo.org/records/15426601).\n",
                "Please cite the Zenodo record if you use this model in published work.\n",
                "You can download it manually for your own use, or simply run the cell below to download it programmatically using Python:"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 2,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/plain": [
                            "('chemeleon_mp.pt', <http.client.HTTPMessage at 0x751223631d60>)"
                        ]
                    },
                    "execution_count": 2,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "from urllib.request import urlretrieve\n",
                "\n",
                "urlretrieve(\n",
                "    r\"https://zenodo.org/records/15460715/files/chemeleon_mp.pt\",\n",
                "    \"chemeleon_mp.pt\",\n",
                ")"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Initializing `CheMeleon`\n",
                "\n",
                "`CheMeleon` uses the following classes for featurization, message passing, and aggregation:"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 3,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/plain": [
                            "<All keys matched successfully>"
                        ]
                    },
                    "execution_count": 3,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "import torch\n",
                "\n",
                "from chemprop import featurizers, nn\n",
                "\n",
                "featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()\n",
                "agg = nn.MeanAggregation()\n",
                "chemeleon_mp = torch.load(\"chemeleon_mp.pt\", weights_only=True)\n",
                "mp = nn.BondMessagePassing(**chemeleon_mp['hyper_parameters'])\n",
                "mp.load_state_dict(chemeleon_mp['state_dict'])"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "If you have existing Chemprop training code, you can simply replace your `agg`, `featurizer`, and `mp` with the ones defined above to take advantage of `CheMeleon` immediately!\n",
                "\n",
                "In general, we suggest continuing to train the `CheMeleon` weights during finetuning.\n",
                "In some cases, however, freezing the weights (`mp.eval()` followed by `mp.apply(lambda module: module.requires_grad_(False))`) may improve performance."
            ]
        },
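        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "The freezing option mentioned above can be sketched as follows.\n",
                "This is only an illustration under stated assumptions: `freeze_module` is a hypothetical helper (not part of Chemprop), and it is applied to a small stand-in `torch.nn.Linear` layer rather than to `mp`, so the model trained in the rest of this notebook keeps its weights trainable.\n",
                "To freeze `CheMeleon` itself, you would pass `mp` to the helper instead."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import torch\n",
                "\n",
                "\n",
                "def freeze_module(module):\n",
                "    # Hypothetical helper for illustration; not part of the Chemprop API\n",
                "    module.eval()  # switch layers such as dropout to inference behavior\n",
                "    module.apply(lambda m: m.requires_grad_(False))  # stop gradient updates for every submodule\n",
                "    return module\n",
                "\n",
                "\n",
                "# Stand-in layer (an assumption for this sketch) so that the real `mp` is left untouched\n",
                "demo = freeze_module(torch.nn.Linear(4, 8))\n",
                "\n",
                "# Count the parameters that would still receive gradient updates\n",
                "n_trainable = sum(p.numel() for p in demo.parameters() if p.requires_grad)\n",
                "print(n_trainable)"
            ]
        },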
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Standard Chemprop Preparation\n",
                "\n",
                "The code below handles importing the needed modules, setting up the data, and initializing the Chemprop model.\n",
                "It's **mostly** the same as the `training` example provided in the Chemprop repository; for a more detailed breakdown, [see the docs here](https://chemprop.readthedocs.io/en/latest/training.html).\n",
                "\n",
                "The one important change is that we must set `input_dim=mp.output_dim` when we initialize our FFN.\n",
                "This ensures that the dimension of the learned representation from `CheMeleon` matches the input size for the regressor.\n",
                "Note also that, to make the `CheMeleon` model useful, you must set up your own FFN to regress the target you care about, in this case lipophilicity.\n",
                "\n",
                "You can also use `CheMeleon` for classification tasks.\n",
                "See the [regular classification demo notebook](https://chemprop.readthedocs.io/en/latest/training_classification.html), which can be modified as shown above to load `CheMeleon`."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 4,
            "metadata": {},
            "outputs": [
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "The return type of make_split_indices has changed in v2.1 - see help(make_split_indices)\n"
                    ]
                }
            ],
            "source": [
                "from pathlib import Path\n",
                "\n",
                "from lightning import pytorch as pl\n",
                "from lightning.pytorch.callbacks import ModelCheckpoint\n",
                "import pandas as pd\n",
                "\n",
                "from chemprop import data, models\n",
                "\n",
                "chemprop_dir = Path.cwd().parent\n",
                "input_path = chemprop_dir / \"tests\" / \"data\" / \"regression\" / \"mol\" / \"mol.csv\" # path to your data .csv file\n",
                "num_workers = 0 # number of workers for dataloader. 0 means using main process for data loading\n",
                "smiles_column = 'smiles' # name of the column containing SMILES strings\n",
                "target_columns = ['lipo'] # list of names of the columns containing targets\n",
                "df_input = pd.read_csv(input_path)\n",
                "smis = df_input.loc[:, smiles_column].values\n",
                "ys = df_input.loc[:, target_columns].values\n",
                "all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]\n",
                "mols = [d.mol for d in all_data]  # RDKit Mol objects are used for structure-based splits\n",
                "train_indices, val_indices, test_indices = data.make_split_indices(mols, \"random\", (0.8, 0.1, 0.1))  # unpack the tuple into three separate lists\n",
                "train_data, val_data, test_data = data.split_data_by_indices(\n",
                "    all_data, train_indices, val_indices, test_indices\n",
                ")\n",
                "train_dset = data.MoleculeDataset(train_data[0], featurizer)\n",
                "scaler = train_dset.normalize_targets()\n",
                "val_dset = data.MoleculeDataset(val_data[0], featurizer)\n",
                "val_dset.normalize_targets(scaler)\n",
                "test_dset = data.MoleculeDataset(test_data[0], featurizer)\n",
                "train_loader = data.build_dataloader(train_dset, num_workers=num_workers)\n",
                "val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)\n",
                "test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)\n",
                "output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)\n",
                "ffn = nn.RegressionFFN(output_transform=output_transform, input_dim=mp.output_dim)\n",
                "metric_list = [nn.metrics.RMSE(), nn.metrics.MAE()]\n",
                "mpnn = models.MPNN(mp, agg, ffn, batch_norm=False, metrics=metric_list)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Now we can take a look at the model, which includes the large message passing block from `CheMeleon`:"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 5,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/plain": [
                            "MPNN(\n",
                            "  (message_passing): BondMessagePassing(\n",
                            "    (W_i): Linear(in_features=86, out_features=2048, bias=False)\n",
                            "    (W_h): Linear(in_features=2048, out_features=2048, bias=False)\n",
                            "    (W_o): Linear(in_features=2120, out_features=2048, bias=True)\n",
                            "    (dropout): Dropout(p=0.0, inplace=False)\n",
                            "    (tau): ReLU()\n",
                            "    (V_d_transform): Identity()\n",
                            "    (graph_transform): Identity()\n",
                            "  )\n",
                            "  (agg): MeanAggregation()\n",
                            "  (bn): Identity()\n",
                            "  (predictor): RegressionFFN(\n",
                            "    (ffn): MLP(\n",
                            "      (0): Sequential(\n",
                            "        (0): Linear(in_features=2048, out_features=300, bias=True)\n",
                            "      )\n",
                            "      (1): Sequential(\n",
                            "        (0): ReLU()\n",
                            "        (1): Dropout(p=0.0, inplace=False)\n",
                            "        (2): Linear(in_features=300, out_features=1, bias=True)\n",
                            "      )\n",
                            "    )\n",
                            "    (criterion): MSE(task_weights=[[1.0]])\n",
                            "    (output_transform): UnscaleTransform()\n",
                            "  )\n",
                            "  (X_d_transform): Identity()\n",
                            "  (metrics): ModuleList(\n",
                            "    (0): RMSE(task_weights=[[1.0]])\n",
                            "    (1): MAE(task_weights=[[1.0]])\n",
                            "    (2): MSE(task_weights=[[1.0]])\n",
                            "  )\n",
                            ")"
                        ]
                    },
                    "execution_count": 5,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "mpnn"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Training\n",
                "\n",
                "The remainder of this notebook again follows the typical training routine.\n",
                "With the addition of `CheMeleon`, your model may take longer to complete a single epoch due to the increased number of parameters, but it will (hopefully!) perform better, particularly on small datasets, and require fewer epochs to converge!"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 6,
            "metadata": {},
            "outputs": [
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "GPU available: True (cuda), used: True\n",
                        "TPU available: False, using: 0 TPU cores\n",
                        "HPU available: False, using: 0 HPUs\n"
                    ]
                }
            ],
            "source": [
                "# Configure model checkpointing\n",
                "checkpointing = ModelCheckpoint(\n",
                "    \"checkpoints\",  # Directory where model checkpoints will be saved\n",
                "    \"best-{epoch}-{val_loss:.2f}\",  # Filename format for checkpoints, including epoch and validation loss\n",
                "    \"val_loss\",  # Metric used to select the best checkpoint (based on validation loss)\n",
                "    mode=\"min\",  # Save the checkpoint with the lowest validation loss (minimization objective)\n",
                "    save_last=True,  # Always save the most recent checkpoint, even if it's not the best\n",
                ")\n",
                "trainer = pl.Trainer(\n",
                "    logger=False,\n",
                "    enable_checkpointing=True, # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.\n",
                "    enable_progress_bar=True,\n",
                "    accelerator=\"auto\",\n",
                "    devices=1,\n",
                "    max_epochs=20, # number of epochs to train for\n",
                "    callbacks=[checkpointing], # Use the configured checkpoint callback\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 7,
            "metadata": {},
            "outputs": [
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "/home/jackson/miniconda3/envs/chemprop_dev/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/jackson/chemprop/examples/checkpoints exists and is not empty.\n",
                        "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
                        "Loading `train_dataloader` to estimate number of stepping batches.\n",
                        "/home/jackson/miniconda3/envs/chemprop_dev/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n",
                        "\n",
                        "  | Name            | Type               | Params | Mode \n",
                        "---------------------------------------------------------------\n",
                        "0 | message_passing | BondMessagePassing | 8.7 M  | train\n",
                        "1 | agg             | MeanAggregation    | 0      | train\n",
                        "2 | bn              | Identity           | 0      | train\n",
                        "3 | predictor       | RegressionFFN      | 615 K  | train\n",
                        "4 | X_d_transform   | Identity           | 0      | train\n",
                        "5 | metrics         | ModuleList         | 0      | train\n",
                        "---------------------------------------------------------------\n",
                        "9.3 M     Trainable params\n",
                        "0         Non-trainable params\n",
                        "9.3 M     Total params\n",
                        "37.317    Total estimated model params size (MB)\n",
                        "25        Modules in train mode\n",
                        "0         Modules in eval mode\n"
                    ]
                },
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]"
                    ]
                },
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "/home/jackson/miniconda3/envs/chemprop_dev/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n"
                    ]
                },
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "Epoch 19: 100%|██████████| 2/2 [00:00<00:00,  9.11it/s, train_loss_step=0.0298, val_loss=0.548, train_loss_epoch=0.0245]"
                    ]
                },
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "`Trainer.fit` stopped: `max_epochs=20` reached.\n"
                    ]
                },
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "Epoch 19: 100%|██████████| 2/2 [00:00<00:00,  4.84it/s, train_loss_step=0.0298, val_loss=0.548, train_loss_epoch=0.0245]\n"
                    ]
                }
            ],
            "source": [
                "trainer.fit(mpnn, train_loader, val_loader)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 8,
            "metadata": {},
            "outputs": [
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "/home/jackson/miniconda3/envs/chemprop_dev/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:145: `.test(ckpt_path=None)` was called without a model. The best model of the previous `fit` call will be used. You can pass `.test(ckpt_path='best')` to use the best model or `.test(ckpt_path='last')` to use the last model. If you pass a value, this warning will be silenced.\n",
                        "Restoring states from the checkpoint path at /home/jackson/chemprop/examples/checkpoints/best-epoch=16-val_loss=0.49.ckpt\n",
                        "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
                        "Loaded model weights from the checkpoint at /home/jackson/chemprop/examples/checkpoints/best-epoch=16-val_loss=0.49.ckpt\n",
                        "/home/jackson/miniconda3/envs/chemprop_dev/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n"
                    ]
                },
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 87.82it/s] \n"
                    ]
                },
                {
                    "data": {
                        "text/html": [
                            "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
                            "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
                            "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
                            "│<span style=\"color: #008080; text-decoration-color: #008080\">         test/mae          </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.5628008842468262     </span>│\n",
                            "│<span style=\"color: #008080; text-decoration-color: #008080\">         test/rmse         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.6701957583427429     </span>│\n",
                            "└───────────────────────────┴───────────────────────────┘\n",
                            "</pre>\n"
                        ],
                        "text/plain": [
                            "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
                            "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
                            "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
                            "│\u001b[36m \u001b[0m\u001b[36m        test/mae         \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.5628008842468262    \u001b[0m\u001b[35m \u001b[0m│\n",
                            "│\u001b[36m \u001b[0m\u001b[36m        test/rmse        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.6701957583427429    \u001b[0m\u001b[35m \u001b[0m│\n",
                            "└───────────────────────────┴───────────────────────────┘\n"
                        ]
                    },
                    "metadata": {},
                    "output_type": "display_data"
                }
            ],
            "source": [
                "results = trainer.test(dataloaders=test_loader)"
            ]
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "chemprop_dev",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.12.9"
        },
        "orig_nbformat": 4
    },
    "nbformat": 4,
    "nbformat_minor": 2
}
